package modules.Hamming

import chisel3._
import chisel3.util._

import scala.collection.mutable.ArrayBuffer

/** SECDED (extended Hamming) encoder: maps `io.data` (64 bits) to a 72-bit codeword on `io.hamming`.
  *
  * Bit `i` of `io.hamming` is codeword position `i`; see [[Hamming]] for which positions hold data and
  * which hold parity. The codeword is built so that every Hamming check and the XOR of all 72 bits are 0.
  */
class HammingEncode extends Module {
  val io = IO(new HammingBaseIO)
  val sourceData = io.data.asBools
  /** Data bits scattered to their codeword positions; parity positions hold a `false` placeholder. */
  val sourceVec = Hamming.indexInv.map(i => if (i < 0) false.B else sourceData(i))
  // Positions 1..71: data bits pass through; the Hamming parity bit at power-of-two position p is the
  // XOR of every position whose index has bit p set (placeholders are false, so only data contributes).
  private val protectedBits = sourceVec.zipWithIndex.tail.map {
    case (_, pos) if Hamming.indexInv(pos) < 0 => Hamming.parityOver(sourceVec, pos)
    case (bit, _)                              => bit
  }
  /** Full codeword. Position 0 is the overall parity over positions 1..71 (data AND Hamming parity),
    * which is what lets the decoder tell single errors from double errors.
    */
  val data = protectedBits.reduce(_ ^ _) +: protectedBits
  io.hamming := VecInit(data).asUInt
  // printf("Enc: src=%b, enc=%b\n", io.data, io.hamming)
}

/** Layout of the 72-bit SECDED codeword shared by [[HammingEncode]] and [[HammingDecode]].
  *
  * Position 0 holds the overall parity bit, positions 1, 2, 4, 8, 16, 32 and 64 hold the Hamming parity
  * bits, and the remaining 64 positions (3, 5, 6, 7, 9, ...) hold data bits 0..63 in ascending order.
  */
object Hamming {
  /** Number of data bits carried by one codeword. */
  val dataWidth: Int = 64
  /** Total codeword width: data bits + 7 Hamming parity bits + 1 overall parity bit. */
  val codeWidth: Int = 72
  /** Number of Hamming parity bits, i.e. the width of the decoder syndrome (7). */
  val syndromeWidth: Int = log2Ceil(codeWidth)

  def hammingDataDefault: Vec[Bool] = VecInit(Seq.fill(codeWidth)(false.B))

  def hammingToSource(vec: Vec[Bool]): UInt = vec.asUInt

  /** Codeword position of each data bit: every position >= 3 that is not a power of two. */
  val indexList: IndexedSeq[Int] = (3 until codeWidth).filter(i => (i & (i - 1)) != 0)
  require(indexList.length == dataWidth)
  /** For each codeword position, the data-bit index stored there, or -1 for a parity position. */
  val indexInv: IndexedSeq[Int] = (0 until codeWidth).map(indexList.indexOf(_))

  /** XOR of every element of `bits` whose position index shares at least one set bit with `mask`. */
  def parityOver(bits: Iterable[Bool], mask: Int): Bool =
    bits.zipWithIndex.collect { case (bit, pos) if (pos & mask) != 0 => bit }.reduce(_ ^ _)
}

/** Encoder-side ports: 64-bit `data` in, 72-bit `hamming` codeword out. */
class HammingBaseIO extends Bundle {
  val data = Input(UInt(Hamming.dataWidth.W))
  val hamming = Output(UInt(Hamming.codeWidth.W))
}

/** Decoder ports, intended to be used `Flipped`: codeword in; data, `error` and `fixed` out. */
class HammingIO extends HammingBaseIO {
  val error = Input(Bool())
  val fixed = Input(Bool())
}

/** SECDED decoder: checks a 72-bit codeword, corrects a single-bit error and detects double-bit errors.
  *
  * `io.error` is set when any error is detected. `io.fixed` is set when the error was treated as a
  * single-bit error and corrected. Assuming at most two flipped bits, `io.data` is valid iff
  * `!io.error || io.fixed`.
  */
class HammingDecode extends Module {
  val io = IO(Flipped(new HammingIO))
  val data = VecInit(io.hamming.asBools)
  // checking(0): XOR of all 72 bits -- 1 means an odd number of bits flipped.
  // checking(n), n >= 1: syndrome bit n-1, the XOR of every position whose index has bit n-1 set.
  val checking = (0 to Hamming.syndromeWidth).map { n =>
    if (n == 0) data.reduce(_ ^ _) else Hamming.parityOver(data, 1 << (n - 1))
  }
  val isErr = checking.head
  /** Syndrome: position of a single flipped bit, or 0 when positions 1..71 pass every Hamming check. */
  val checkedIndex = VecInit(checking.tail).asUInt
  // Overall-parity failure => treat as a single error at `checkedIndex` (0 means bit 0 itself flipped).
  // A syndrome past the end of the codeword can only come from >= 3 flips, so never write out of range.
  private val correctable = isErr && checkedIndex < data.length.U
  val fixData = WireInit(data)
  when(correctable) {
    fixData(checkedIndex) := !data(checkedIndex)
  }
  val destVec = Hamming.indexList.map(i => fixData(i))
  io.data := VecInit(destVec).asUInt
  // Overall parity OK but nonzero syndrome => even number (two) of flips: detected, not correctable.
  io.error := isErr || checkedIndex =/= 0.U
  io.fixed := correctable
  // printf("Dec: dec=%b, err=%b, idx=%b, fix=%b\n", io.data, isErr, checkedIndex, io.fixed)
}